package com.kool.kmqtt.server.session;

import com.alibaba.fastjson.JSON;
import com.kool.kmqtt.server.PacketSender;
import com.kool.kmqtt.server.constant.PacketTypeEnum;
import com.kool.kmqtt.server.encoder.PublishPacketEncoder;
import com.kool.kmqtt.server.packet.*;
import com.kool.kmqtt.server.repository.Repository;
import com.kool.kmqtt.server.repository.RepositoryFactory;
import com.kool.kmqtt.server.repository.subscription.Subscriber;
import lombok.extern.slf4j.Slf4j;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Session cache.
 *
 * <p>Holds the live {@link SessionContext}s keyed by session id, plus a secondary index from
 * clientId to session id. Both maps are {@link ConcurrentHashMap}s, so individual lookups and
 * updates are thread-safe; updates that touch both maps (e.g. {@link #put}) are not atomic as a
 * pair.
 */
@Slf4j
public class SessionHolder {
    private static final SessionHolder INSTANCE = new SessionHolder();

    /**
     * Session context cache: key = sessionId, value = session context.
     */
    private final ConcurrentHashMap<String, SessionContext> sessionMap = new ConcurrentHashMap<>();
    /**
     * key=clientId
     * value=sessionId
     */
    private final ConcurrentHashMap<String, String> clientIdSessionId = new ConcurrentHashMap<>();

    /**
     * Returns the shared instance.
     *
     * @return the singleton {@code SessionHolder}
     */
    public static SessionHolder getInstance() {
        return INSTANCE;
    }

    /**
     * Caches a session context and, if it carries a clientId, (re)binds that clientId to
     * {@code sessionId}.
     *
     * @param sessionId      id of the session; must not be null
     * @param sessionContext context to cache; must not be null
     * @throws IllegalArgumentException if either argument is null
     */
    public void put(String sessionId, SessionContext sessionContext) {
        if (sessionId == null) {
            throw new IllegalArgumentException("会话id为空");
        }
        if (sessionContext == null) {
            throw new IllegalArgumentException("会话上下文为空");
        }
        sessionMap.put(sessionId, sessionContext);
        if (sessionContext.getClientId() != null) {
            clientIdSessionId.put(sessionContext.getClientId(), sessionId);
        }
    }

    /**
     * Looks up the session currently bound to a clientId.
     *
     * @param clientId client identifier; may be null
     * @return the bound session context, or null if none (or if {@code clientId} is null)
     */
    public SessionContext getByClientId(String clientId) {
        if (clientId == null) {
            return null;
        }
        String sessionId = clientIdSessionId.get(clientId);
        if (sessionId == null) {
            return null;
        }
        return sessionMap.get(sessionId);
    }

    /**
     * Looks up a session by its id.
     *
     * @param sessionId session id; may be null
     * @return the session context, or null if none (or if {@code sessionId} is null)
     */
    public SessionContext getBySessionId(String sessionId) {
        if (sessionId == null) {
            return null;
        }
        return sessionMap.get(sessionId);
    }

    /**
     * Removes a session from the cache, together with its clientId binding if that binding
     * still points at this session.
     *
     * @param sessionId session id to remove; null is ignored
     */
    public void remove(String sessionId) {
        if (sessionId == null) {
            return;
        }
        // Atomic get-and-remove instead of a separate lookup followed by a remove.
        SessionContext sessionContext = sessionMap.remove(sessionId);
        if (sessionContext == null) {
            return;
        }
        String clientId = sessionContext.getClientId();
        if (clientId != null) {
            // Conditional remove: if the client has reconnected, put() has already re-bound the
            // clientId to a newer session, and that binding must survive the old session's removal.
            clientIdSessionId.remove(clientId, sessionId);
        }
    }

    /**
     * Returns a snapshot of all cached session contexts, deep-copied via a fastjson
     * serialize/deserialize round trip so callers do not receive the cached instances.
     *
     * <p>NOTE(review): fields fastjson cannot round-trip (possibly the channel reference behind
     * {@code getCtx()}) may not survive the copy - confirm before relying on them.
     *
     * @return a new list of copied session contexts
     */
    public List<SessionContext> getAllCopy() {
        return JSON.parseArray(JSON.toJSONString(new ArrayList<>(sessionMap.values())), SessionContext.class);
    }

    /**
     * Counts the cached sessions.
     *
     * @return number of sessions currently held
     */
    public int countSession() {
        return sessionMap.size();
    }

    /**
     * Server-initiated disconnect of a client.
     *
     * <p>Processes the client's will (if its will flag is set), then removes the session from the
     * cache, deletes the persisted session status for clean sessions, and closes the network
     * channel. The cleanup runs even if will processing throws, so a failing will can no longer
     * leak the session or the connection; the exception still propagates.
     *
     * @param sessionContext session to close; null is ignored
     */
    public static void close(SessionContext sessionContext) {
        if (sessionContext == null) {
            return;
        }
        String clientId = sessionContext.getClientId();
        log.info("服务端主动断开客户端[{}]的连接", clientId);
        try {
            if (clientId != null) {
                handleWill(RepositoryFactory.getRepository(), clientId);
            }
        } finally {
            releaseSession(sessionContext);
        }
    }

    /**
     * Publishes the client's will message if its will flag is set, and additionally stores it as
     * the retained message for the will topic when will-retain is set.
     */
    private static void handleWill(Repository repository, String clientId) {
        WillStatus willStatus = repository.getWillStatus(clientId);
        // Only process the will when the will flag == 1.
        if (willStatus == null || !willStatus.isWillFlag()) {
            return;
        }
        int willQoS = willStatus.getWillQoS();
        String willTopic = willStatus.getWillTopic();
        String willMessage = willStatus.getWillMessage();
        if (willStatus.isWillRetain()) {
            // Retained will: store it so future subscribers of the topic receive it.
            PublishVariableHeader variableHeader = new PublishVariableHeader();
            variableHeader.setTopicName(willTopic);
            repository.saveRetainPacket(willTopic, buildPublishPacket(willQoS, true, variableHeader, willMessage));
        }
        // MQTT 3.1.1 §3.1.2.7: a retained will is still published to current subscribers;
        // retaining is in addition to, not instead of, publishing.
        sendWillMessage(willQoS, willTopic, willMessage);
    }

    /**
     * Removes the session from the cache, deletes its persisted status for clean sessions, and
     * closes the network channel if one is still attached.
     */
    private static void releaseSession(SessionContext sessionContext) {
        // Drop the cached session context.
        getInstance().remove(sessionContext.getSessionId());
        String clientId = sessionContext.getClientId();
        if (Boolean.TRUE.equals(sessionContext.getCleanSession()) && clientId != null) {
            // Clean session: the persisted session status must not outlive the connection.
            RepositoryFactory.getRepository().deleteSessionStatus(clientId);
        }
        // NOTE(review): getCtx().get() is evaluated twice; if getCtx() is a weak/soft reference
        // it could be cleared in between - consider reading it once into a local.
        if (sessionContext.getCtx() != null
                && sessionContext.getCtx().get() != null) {
            // Close the network connection.
            sessionContext.getCtx().get().close();
        }
    }

    /**
     * Sends the will message as a non-retained PUBLISH to every subscriber of
     * {@code willTopic} that currently has a session in this holder; subscribers without a live
     * session are skipped. QoS 1/2 deliveries get a packet id generated for the subscriber's
     * clientId.
     *
     * <p>NOTE(review): every subscriber receives {@code willQoS}; MQTT 3.1.1 §3.8.4 caps delivery at
     * the subscription's granted QoS - confirm whether {@link Subscriber} exposes it.
     */
    private static void sendWillMessage(int willQoS, String willTopic, String willMessage) {
        // Look up the subscribers of the topic.
        List<Subscriber> subscribers = RepositoryFactory.getRepository().getSubscriber(willTopic);
        if (subscribers == null) {
            return;
        }
        boolean needsPacketId = willQoS == 1 || willQoS == 2;
        for (Subscriber subscriber : subscribers) {
            String clientId = subscriber.getClientId();
            SessionContext sessionContext = getInstance().getByClientId(clientId);
            if (sessionContext == null) {
                continue;
            }
            PublishVariableHeader variableHeader = new PublishVariableHeader();
            variableHeader.setTopicName(willTopic);
            if (needsPacketId) {
                // QoS 1/2 deliveries consume one packet id of the receiving client.
                variableHeader.setPacketId(PacketIdGenerator.generateId(clientId));
            }
            // Forwarded to an established subscription, so the RETAIN flag is 0 (MQTT 3.1.1 §3.3.1.3).
            Packet packet = buildPublishPacket(willQoS, false, variableHeader, willMessage);
            if (needsPacketId) {
                packet.setPacketId(variableHeader.getPacketId());
            }
            PacketSender packetSender = new PacketSender(sessionContext, new PublishPacketEncoder());
            packetSender.send(packet);
        }
    }

    /**
     * Builds a PUBLISH packet (DUP = 0) from the given QoS, retain flag, variable header and
     * message text; a null message yields an empty payload.
     */
    private static Packet buildPublishPacket(int qos, boolean retain, PublishVariableHeader variableHeader,
                                             String message) {
        FixedHeader fixedHeader = new FixedHeader();
        fixedHeader.setPacketType(PacketTypeEnum.PUBLISH.getCode());
        fixedHeader.setDup(false);
        fixedHeader.setQoS(qos);
        fixedHeader.setRetain(retain);
        PublishPayload payload = new PublishPayload();
        // MQTT strings are UTF-8; do not depend on the platform default charset.
        payload.setPayload(message == null ? new byte[0] : message.getBytes(StandardCharsets.UTF_8));
        Packet packet = new Packet();
        packet.setFixedHeader(fixedHeader);
        packet.setVariableHeader(variableHeader);
        packet.setPayload(payload);
        return packet;
    }
}
